{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🤪 Variational Autoencoders - CelebA Faces"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own variational autoencoder on the CelebA faces dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.backend as K\n",
    "from tensorflow.keras import (\n",
    "    layers,\n",
    "    models,\n",
    "    callbacks,\n",
    "    utils,\n",
    "    metrics,\n",
    "    losses,\n",
    "    optimizers,\n",
    ")\n",
    "\n",
    "from scipy.stats import norm\n",
    "import pandas as pd\n",
    "\n",
    "from notebooks.utils import sample_batch, display\n",
    "\n",
    "from vae_utils import get_vector_from_label, add_vector_to_images, morph_faces"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side length, in pixels, that images are resized to (also the encoder input size)\n",
    "IMAGE_SIZE = 32\n",
    "# Colour channels of the encoder input and the decoder output\n",
    "CHANNELS = 3\n",
    "# Images per batch produced by the dataset loaders\n",
    "BATCH_SIZE = 128\n",
    "# Filter count of every hidden Conv2D / Conv2DTranspose layer\n",
    "NUM_FEATURES = 128\n",
    "# Dimensionality of the latent space\n",
    "Z_DIM = 200\n",
    "# Learning rate passed to the Adam optimizer\n",
    "LEARNING_RATE = 0.0005\n",
    "# Number of epochs passed to fit()\n",
    "EPOCHS = 10\n",
    "# Weight applied to the mean-squared-error reconstruction loss\n",
    "BETA = 2000\n",
    "# When True, previously saved weights are loaded before training\n",
    "LOAD_MODEL = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7716fac-0010-49b0-b98e-53be2259edde",
   "metadata": {},
   "source": [
    "## 1. Prepare the data <a name=\"prepare\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stream the unlabelled CelebA images from disk in shuffled batches,\n",
    "# resizing each one to IMAGE_SIZE x IMAGE_SIZE with bilinear interpolation\n",
    "train_data = utils.image_dataset_from_directory(\n",
    "    directory=\"/app/data/celeba-dataset/img_align_celeba/img_align_celeba\",\n",
    "    labels=None,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    image_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
    "    interpolation=\"bilinear\",\n",
    "    color_mode=\"rgb\",\n",
    "    shuffle=True,\n",
    "    seed=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebae2f0d-59fd-4796-841f-7213eae638de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(img):\n",
    "    \"\"\"Cast an image batch to float32 and divide it by 255.\"\"\"\n",
    "    return tf.cast(img, \"float32\") / 255.0\n",
    "\n",
    "\n",
    "# Apply the rescaling to every batch of the training dataset\n",
    "train = train_data.map(preprocess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b03f32fd-addb-4c9b-906c-a5f1934df7e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull a sample from the preprocessed dataset with the project helper `sample_batch`\n",
    "# (presumably a single batch of images -- TODO confirm in notebooks/utils)\n",
    "train_sample = sample_batch(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show some faces from the training set\n",
    "# (the images were loaded with color_mode=\"rgb\", so no colour map is requested)\n",
    "display(train_sample, cmap=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2. Build the variational autoencoder <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a0625b6-3c19-478b-84f9-5c2b5c2b74b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Sampling(layers.Layer):\n",
    "    \"\"\"Layer that draws a latent vector from the Gaussian given by (z_mean, z_log_var).\"\"\"\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return z_mean + exp(0.5 * z_log_var) * epsilon for random normal epsilon.\"\"\"\n",
    "        mean, log_var = inputs\n",
    "        mean_shape = tf.shape(mean)\n",
    "        # One noise value per (sample, latent dimension) entry\n",
    "        noise = K.random_normal(shape=(mean_shape[0], mean_shape[1]))\n",
    "        std_dev = tf.exp(0.5 * log_var)\n",
    "        return mean + std_dev * noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086e2584-c60d-4990-89f4-2092c44e023e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder: four stride-2 convolution blocks, then two dense heads and a sampling layer\n",
    "encoder_input = layers.Input(\n",
    "    shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS), name=\"encoder_input\"\n",
    ")\n",
    "x = encoder_input\n",
    "for _ in range(4):\n",
    "    # Each block: strided convolution -> batch normalisation -> LeakyReLU\n",
    "    x = layers.Conv2D(NUM_FEATURES, kernel_size=3, strides=2, padding=\"same\")(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "    x = layers.LeakyReLU()(x)\n",
    "\n",
    "# Remember the feature-map shape (without the batch axis); the decoder reshapes back to it\n",
    "shape_before_flattening = K.int_shape(x)[1:]\n",
    "\n",
    "x = layers.Flatten()(x)\n",
    "z_mean = layers.Dense(Z_DIM, name=\"z_mean\")(x)\n",
    "z_log_var = layers.Dense(Z_DIM, name=\"z_log_var\")(x)\n",
    "z = Sampling()([z_mean, z_log_var])\n",
    "\n",
    "encoder = models.Model(\n",
    "    inputs=encoder_input, outputs=[z_mean, z_log_var, z], name=\"encoder\"\n",
    ")\n",
    "encoder.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88010f20-fb61-498c-b2b2-dac96f6c03b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder: dense projection, reshape to the encoder's feature-map shape,\n",
    "# four stride-2 transposed-convolution blocks and a sigmoid output layer\n",
    "decoder_input = layers.Input(shape=(Z_DIM,), name=\"decoder_input\")\n",
    "x = layers.Dense(np.prod(shape_before_flattening))(decoder_input)\n",
    "x = layers.BatchNormalization()(x)\n",
    "x = layers.LeakyReLU()(x)\n",
    "x = layers.Reshape(shape_before_flattening)(x)\n",
    "for _ in range(4):\n",
    "    # Each block: strided transposed convolution -> batch normalisation -> LeakyReLU\n",
    "    x = layers.Conv2DTranspose(\n",
    "        NUM_FEATURES, kernel_size=3, strides=2, padding=\"same\"\n",
    "    )(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "    x = layers.LeakyReLU()(x)\n",
    "\n",
    "# Final stride-1 layer maps the features to CHANNELS output channels\n",
    "decoder_output = layers.Conv2DTranspose(\n",
    "    CHANNELS, kernel_size=3, strides=1, activation=\"sigmoid\", padding=\"same\"\n",
    ")(x)\n",
    "decoder = models.Model(inputs=decoder_input, outputs=decoder_output)\n",
    "decoder.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ad9761-9756-45b3-83ef-ee3d9218d694",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(models.Model):\n",
    "    \"\"\"Variational autoencoder wrapping an encoder model and a decoder model.\n",
    "\n",
    "    The objective is the sum of a reconstruction term (mean squared error\n",
    "    scaled by the module-level ``BETA``) and the KL divergence between the\n",
    "    encoder's diagonal Gaussian and a standard normal prior.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder, decoder, **kwargs):\n",
    "        \"\"\"Store the sub-models and create a running-mean tracker per loss term.\"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.total_loss_tracker = metrics.Mean(name=\"total_loss\")\n",
    "        self.reconstruction_loss_tracker = metrics.Mean(\n",
    "            name=\"reconstruction_loss\"\n",
    "        )\n",
    "        self.kl_loss_tracker = metrics.Mean(name=\"kl_loss\")\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        \"\"\"Return the loss trackers so Keras can manage (e.g. reset) them.\"\"\"\n",
    "        return [\n",
    "            self.total_loss_tracker,\n",
    "            self.reconstruction_loss_tracker,\n",
    "            self.kl_loss_tracker,\n",
    "        ]\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        \"\"\"Encode ``inputs`` and decode the sampled latent vector.\n",
    "\n",
    "        Returns the tuple ``(z_mean, z_log_var, reconstruction)``.\n",
    "        \"\"\"\n",
    "        # Use the sub-models owned by this instance, not same-named module globals,\n",
    "        # and forward the training flag to them.\n",
    "        z_mean, z_log_var, z = self.encoder(inputs, training=training)\n",
    "        reconstruction = self.decoder(z, training=training)\n",
    "        return z_mean, z_log_var, reconstruction\n",
    "\n",
    "    def _vae_losses(self, data, training):\n",
    "        \"\"\"Run a forward pass and return ``(total, reconstruction, kl)`` losses.\"\"\"\n",
    "        z_mean, z_log_var, reconstruction = self(data, training=training)\n",
    "        reconstruction_loss = tf.reduce_mean(\n",
    "            BETA * losses.mean_squared_error(data, reconstruction)\n",
    "        )\n",
    "        # KL term: summed over the latent axis, then averaged over the batch\n",
    "        kl_loss = tf.reduce_mean(\n",
    "            tf.reduce_sum(\n",
    "                -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),\n",
    "                axis=1,\n",
    "            )\n",
    "        )\n",
    "        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss\n",
    "\n",
    "    def _update_loss_trackers(self, total_loss, reconstruction_loss, kl_loss):\n",
    "        \"\"\"Feed the batch losses to the trackers and return their running means.\"\"\"\n",
    "        self.total_loss_tracker.update_state(total_loss)\n",
    "        self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n",
    "        self.kl_loss_tracker.update_state(kl_loss)\n",
    "        return {\n",
    "            \"loss\": self.total_loss_tracker.result(),\n",
    "            \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
    "            \"kl_loss\": self.kl_loss_tracker.result(),\n",
    "        }\n",
    "\n",
    "    def train_step(self, data):\n",
    "        \"\"\"Step run during training.\"\"\"\n",
    "        # Datasets yielding (images, labels) tuples are reduced to the images,\n",
    "        # matching what test_step does.\n",
    "        if isinstance(data, tuple):\n",
    "            data = data[0]\n",
    "\n",
    "        with tf.GradientTape() as tape:\n",
    "            total_loss, reconstruction_loss, kl_loss = self._vae_losses(\n",
    "                data, training=True\n",
    "            )\n",
    "\n",
    "        grads = tape.gradient(total_loss, self.trainable_weights)\n",
    "        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
    "\n",
    "        return self._update_loss_trackers(\n",
    "            total_loss, reconstruction_loss, kl_loss\n",
    "        )\n",
    "\n",
    "    def test_step(self, data):\n",
    "        \"\"\"Step run during validation.\"\"\"\n",
    "        if isinstance(data, tuple):\n",
    "            data = data[0]\n",
    "\n",
    "        total_loss, reconstruction_loss, kl_loss = self._vae_losses(\n",
    "            data, training=False\n",
    "        )\n",
    "        # Report running means over the evaluated batches instead of the\n",
    "        # values of whichever batch happened to be processed last.\n",
    "        return self._update_loss_trackers(\n",
    "            total_loss, reconstruction_loss, kl_loss\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edf2f892-9209-42ee-b251-1e7604df5335",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a variational autoencoder\n",
    "# (wraps the encoder and decoder models into a single VAE model instance)\n",
    "vae = VAE(encoder, decoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 3. Train the variational autoencoder <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b429fdad-ea9c-45a2-a556-eb950d793824",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the variational autoencoder\n",
    "# No loss is passed to compile(): the VAE class computes its losses in train_step/test_step.\n",
    "optimizer = optimizers.Adam(learning_rate=LEARNING_RATE)\n",
    "vae.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint the whole model at the end of an epoch, keeping only the\n",
    "# version with the lowest training loss seen so far\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint\",\n",
    "    monitor=\"loss\",\n",
    "    mode=\"min\",\n",
    "    save_best_only=True,\n",
    "    save_weights_only=False,\n",
    "    save_freq=\"epoch\",\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "# Write TensorBoard event files under ./logs\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    \"\"\"Callback that decodes random latent vectors and saves them as PNG files.\n",
    "\n",
    "    Args:\n",
    "        num_img: number of images to generate at the end of every epoch.\n",
    "        latent_dim: size of the latent vectors fed to the model's decoder.\n",
    "        output_dir: directory the PNG files are written to.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_img, latent_dim, output_dir=\"./output\"):\n",
    "        # Let the Keras base class initialise its own state first\n",
    "        super().__init__()\n",
    "        self.num_img = num_img\n",
    "        self.latent_dim = latent_dim\n",
    "        self.output_dir = output_dir\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Generate ``num_img`` images with the decoder and save them to disk.\"\"\"\n",
    "        random_latent_vectors = tf.random.normal(\n",
    "            shape=(self.num_img, self.latent_dim)\n",
    "        )\n",
    "        generated_images = self.model.decoder(random_latent_vectors)\n",
    "        # Scale and convert to a NumPy array once; images are sliced out below\n",
    "        generated_images = (generated_images * 255).numpy()\n",
    "\n",
    "        # Saving fails when the target directory is missing, so create it on demand\n",
    "        tf.io.gfile.makedirs(self.output_dir)\n",
    "        for i in range(self.num_img):\n",
    "            img = utils.array_to_img(generated_images[i])\n",
    "            img.save(\n",
    "                \"%s/generated_img_%03d_%d.png\" % (self.output_dir, epoch, i)\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d080d9a6-8f53-4984-9f80-5e139e6c8d4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load old weights if required\n",
    "if LOAD_MODEL:\n",
    "    vae.load_weights(\"./models/vae\")\n",
    "    # NOTE(review): the prediction result is never used; presumably this forward pass\n",
    "    # on one batch only serves to build the model after loading -- confirm.\n",
    "    tmp = vae.predict(train.take(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c497b7-fa40-48df-b2bf-541239cc9400",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train on the preprocessed dataset for EPOCHS epochs with three callbacks attached:\n",
    "# model checkpointing, TensorBoard logging and saving 10 generated images per epoch.\n",
    "vae.fit(\n",
    "    train,\n",
    "    epochs=EPOCHS,\n",
    "    callbacks=[\n",
    "        model_checkpoint_callback,\n",
    "        tensorboard_callback,\n",
    "        ImageGenerator(num_img=10, latent_dim=Z_DIM),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028138af-d3a5-4134-b980-d3a8a703e70f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the final models\n",
    "# (the full VAE plus the encoder and the decoder separately, each under ./models)\n",
    "vae.save(\"./models/vae\")\n",
    "encoder.save(\"./models/encoder\")\n",
    "decoder.save(\"./models/decoder\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "682fb369-33fb-4f16-a601-47db56de3fd2",
   "metadata": {},
   "source": [
    "## 3. Reconstruct using the variational autoencoder <a name=\"reconstruct\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1fece5-77a8-4510-be7d-713cc08aee37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take one batch of images to reconstruct (drawn from the `train` dataset)\n",
    "batches_to_predict = 1\n",
    "example_images = train.take(batches_to_predict).get_single_element().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db7fba06-6a5f-49c2-82a7-e6265acf1477",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create autoencoder predictions and display\n",
    "# vae.predict returns (z_mean, z_log_var, reconstruction), mirroring VAE.call.\n",
    "# Note: this rebinds the z_mean / z_log_var names used when building the encoder.\n",
    "z_mean, z_log_var, reconstructions = vae.predict(example_images)\n",
    "print(\"Example real faces\")\n",
    "display(example_images)\n",
    "print(\"Reconstructions\")\n",
    "display(reconstructions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11c06dcb-cd9c-4784-93f8-0cff7002cf49",
   "metadata": {},
   "source": [
    "## 4. Latent space distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "465d6359-486a-457a-a598-a2be6fffa16f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the example images, keeping only the sampled latent vectors\n",
    "_, _, z = vae.encoder.predict(example_images)\n",
    "\n",
    "# Points at which the standard normal density is drawn for comparison\n",
    "x = np.linspace(-3, 3, 100)\n",
    "\n",
    "fig = plt.figure(figsize=(20, 5))\n",
    "fig.subplots_adjust(hspace=0.6, wspace=0.4)\n",
    "\n",
    "# One panel per latent dimension for the first 50 dimensions, on a 5 x 10 grid\n",
    "for dim in range(50):\n",
    "    ax = fig.add_subplot(5, 10, dim + 1)\n",
    "    ax.hist(z[:, dim], density=True, bins=20)\n",
    "    ax.plot(x, norm.pdf(x))\n",
    "    ax.axis(\"off\")\n",
    "    # Label each panel with the index of its latent dimension\n",
    "    ax.text(\n",
    "        x=0.5,\n",
    "        y=-0.35,\n",
    "        s=str(dim),\n",
    "        fontsize=10,\n",
    "        ha=\"center\",\n",
    "        transform=ax.transAxes,\n",
    "    )\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfa68340-5b03-4307-8a0f-2fe2d1658846",
   "metadata": {},
   "source": [
    "## 5. Generate new faces <a name=\"decode\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8189d44f-7b4c-4720-ab79-499cb587202e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample some points in the latent space, from the standard normal distribution\n",
    "# (one Z_DIM-dimensional point per cell of a grid_width x grid_height image grid)\n",
    "grid_width, grid_height = (10, 3)\n",
    "z_sample = np.random.normal(size=(grid_width * grid_height, Z_DIM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b9043c1-53dd-430b-9d84-5bf08f773f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the sampled points\n",
    "# Note: this rebinds `reconstructions`, which previously held the VAE reconstructions.\n",
    "reconstructions = decoder.predict(z_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c50c2b79-0d15-4450-bcc5-77b6891122bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lay the decoded images out on a grid_height x grid_width grid\n",
    "fig = plt.figure(figsize=(18, 5))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "\n",
    "# Subplot positions are 1-based, image indices 0-based\n",
    "for position in range(1, grid_width * grid_height + 1):\n",
    "    ax = fig.add_subplot(grid_height, grid_width, position)\n",
    "    ax.axis(\"off\")\n",
    "    ax.imshow(reconstructions[position - 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c303cf9-9799-45cc-9315-4623fc0f20e6",
   "metadata": {},
   "source": [
    "## 6. Manipulate the images <a name=\"manipulate\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "819635b7-ab3c-4a80-83ef-ce00f0696b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the label dataset\n",
    "attributes = pd.read_csv(\"/app/data/celeba-dataset/list_attr_celeba.csv\")\n",
    "# List the available attribute columns, then preview the first rows\n",
    "print(attributes.columns)\n",
    "attributes.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f319d66-293b-4f38-8744-9e1dda150ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the face data with label attached\n",
    "LABEL = \"Blond_Hair\"  # <- Set this label\n",
    "# The labels come from the chosen column of `attributes`, passed as a plain list.\n",
    "# NOTE(review): this assumes the CSV rows are in the same order as the image files\n",
    "# found under the directory -- confirm.\n",
    "# validation_split=0.2 with subset=\"validation\" keeps only the validation portion.\n",
    "labelled_test = utils.image_dataset_from_directory(\n",
    "    \"/app/data/celeba-dataset/img_align_celeba\",\n",
    "    labels=attributes[LABEL].tolist(),\n",
    "    color_mode=\"rgb\",\n",
    "    image_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    seed=42,\n",
    "    validation_split=0.2,\n",
    "    subset=\"validation\",\n",
    "    interpolation=\"bilinear\",\n",
    ")\n",
    "\n",
    "# Rescale the images exactly like the training data; labels pass through unchanged\n",
    "labelled = labelled_test.map(lambda x, y: (preprocess(x), y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fed6cc52-e0f7-465e-a197-8949dd9fbe82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the attribute vector\n",
    "# (computed by the `vae_utils` helper from the labelled dataset and the trained VAE;\n",
    "# presumably a latent-space direction associated with LABEL -- TODO confirm in vae_utils)\n",
    "attribute_vec = get_vector_from_label(labelled, vae, Z_DIM, LABEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88630224-8067-4510-86a0-2646068a4db7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add vector to images\n",
    "# (the `vae_utils` helper receives the labelled dataset, the VAE and the attribute vector;\n",
    "# presumably it displays the modified faces -- TODO confirm in vae_utils)\n",
    "add_vector_to_images(labelled, vae, attribute_vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd8f22fe-9601-4461-9473-72cb4ef80bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the `vae_utils` morphing helper on the labelled dataset\n",
    "# (presumably interpolates between faces in latent space -- TODO confirm in vae_utils)\n",
    "morph_faces(labelled, vae)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
